{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(page-parameter-learning)=\n",
    "\n",
    "# Parameter learning in Spiking Neural Networks\n",
    "Author: Christian Pehle\n",
    "\n",
    "As we have already seen in {ref}`page-spiking`, neuron models come with parameters, such as membrane time constants, that determine their dynamics. While these parameters are often treated as arbitrarily chosen constants, and are therefore *hyperparameters* in a machine learning context, it has more recently become clear that it can be beneficial to treat them as trainable *parameters*. In this notebook, we will first learn how to initialise a network of LIF neurons with distinct but fixed membrane time constants and voltage thresholds, and then how to incorporate them into the optimisation. As we will see, this is largely facilitated by the existing mechanisms for handling parameters in PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import norse\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining a Network of LIF Neurons with varying membrane time-constants\n",
    "\n",
    "A population of recurrently connected LIF neurons can be instantiated in Norse as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LIFRecurrentCell(input_size=200, hidden_size=100, p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(100.)), autapses=False, dt=0.001)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from norse.torch.module import LIFRecurrentCell\n",
    "\n",
    "m = LIFRecurrentCell(input_size=200, hidden_size=100)\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the `LIFParameters` are initialised to default values. It is easy to sample these values from a distribution instead. We use a normal distribution here, although a lognormal distribution would be a more appropriate choice, since it guarantees that the inverse membrane time constant remains positive, which is an essential requirement. Here we simply choose the standard deviation (20, around the default mean of 100) such that a negative value is very unlikely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, '$\\\\tau_{m}^{-1}$ [ms]')"
      ]
     },
     "execution_count": 236,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAENCAYAAAAFcn7UAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPFUlEQVR4nO3df7DldV3H8ecrVg3tB+AuhIvTRUUNEVfajKQchSgIR6wZZzBtKJmhHC1rslp0pmz6wzUt84/SIUGYMNQxfzBSKAONJpW2GAqECOqGuyK76JCam4q8++N8Ga/L3t17zzn3nnvePh8zd875fs73nM/7Pbv3dc98zvl+v6kqJEn9/MCsC5AkrQ4DXpKaMuAlqSkDXpKaMuAlqSkDXpKa2rCWk23cuLEWFhbWckpJmns33njjvVW1aaXPW9OAX1hYYMeOHWs5pSTNvST/Pc7zXKKRpKYMeElqyoCXpKYMeElqyoCXpKYMeElqyoCXpKYMeElqak0PdJIO5bTt17P7vn1rPu/mIw7nhm2nr/m80moy4LWu7L5vHzu3n7Pm8y5su3rN55RWm0s0ktSUAS9JTRnwktSUAS9JTRnwktSUAS9JTRnwktSU34OXZsyDu7RaDHhpxjy4S6vFJRpJasqAl6SmDHhJasqAl6SmDhnwSR6b5J+T3Jbk1iSvGMaPSnJtkjuG2yNXv1xJ0nIt5x38/cDvV9VPAKcCL0tyIrANuK6qTgCuG7YlSevEIQO+qu6uqk8M978G3AZsBs4FLh92uxx4/irVKEkaw4rW4JMsAE8HPgYcU1V3w+iPAHD01KuTJI1t2QGf5IeAfwB+t6q+uoLnXZhkR5Ide/fuHadGSdIYlhXwSR7GKNzfXlXvGYbvSXLs8PixwJ4DPbeqLq6qrVW1ddOmTdOoWZK0DMv5Fk2AS4DbquovFz10FXD+cP984P3TL0+SNK7lnIvmNODXgJuT3DSMvQrYDrwryQXAXcALVqVCSdJYDhnwVfVRIEs8fMZ0y5EkTYtHskpSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUxtmXYDWn9O2X8/u+/bNZO7NRxw+s3kXtl09s7ml1WDA6yF237ePndvPmXUZa+qGbafPugRp6lyikaSmDHhJasqAl6SmDHhJasqAl6SmDHhJasqAl6SmDHhJasqAl6SmDHhJasqAl6SmDHhJasqAl6SmDhnwSS5NsifJLYvGXpNkd5Kbhp9fWt0yJUkrtZx38JcBZx1g/I1VtWX4+cfpliVJmtQhA76qPgJ8ZQ1qkSRN0SRr8C9P8qlhCefIqVUkSZqKcQP+zcDjgS3A3cBfLLVjkguT7EiyY+/evWNOJ0laqbECvqruqarvVNUDwN8CzzjIvhdX1daq2rpp06Zx65QkrdBYAZ/k2EWbvwzcstS+kqTZOORFt5NcCTwb2JhkF/AnwLOTbAEK2An85uqVKEkaxyEDvqpeeIDhS1ahFknSFHkkqyQ1ZcBLUlMGvCQ1ZcBLUlOH/JBVUk+bjzichW1Xz2zuG7adPpO5v58Y8NL3qVkG7Kz+sHy/cYlGkpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpo6ZMAnuTTJniS3LBo7Ksm1Se4Ybo9c3TIlSSu1nHfwlwFn7Te2Dbiuqk4Arhu2JUnryCEDvqo+Anxlv+FzgcuH+5cDz59uWZKkSY27Bn9MVd0NMNwevdSOSS5M
siPJjr179445nSRppVb9Q9aquriqtlbV1k2bNq32dJKkwbgBf0+SYwGG2z3TK0mSNA3jBvxVwPnD/fOB90+nHEnStCzna5JXAv8GPCnJriQXANuBM5PcAZw5bEuS1pENh9qhql64xENnTLkWSdIUeSSrJDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUxsmeXKSncDXgO8A91fV1mkUJUma3EQBP3hOVd07hdeRJE2RSzSS1NSkAV/Ah5LcmOTCA+2Q5MIkO5Ls2Lt374TTSZKWa9KAP62qTgHOBl6W5Fn771BVF1fV1qraumnTpgmnkyQt10QBX1VfHG73AO8FnjGNoiRJkxs74JM8KskPP3gf+AXglmkVJkmazCTfojkGeG+SB1/n76vqmqlUJUma2NgBX1WfA542xVokSVPk1yQlqSkDXpKaMuAlqSkDXpKaMuAlqSkDXpKaMuAlqSkDXpKaMuAlqSkDXpKaMuAlqSkDXpKaMuAlqalpXHRbklZk8xGHs7Dt6pnMe8O209d83lkx4CWtuVmF7Cz+qMySSzSS1JQBL0lNGfCS1JQBL0lNGfCS1JQBL0lNGfCS1JTfg1+G07Zfz+779s26jDWz+YjDZ12CpCkw4Jdh93372Ln9nFmXIUkr4hKNJDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSUwa8JDVlwEtSU3NzoNMsjyb1yE6ph1ldKvDBudf6SlZzE/AeTSppUrO8Huss/rC4RCNJTRnwktSUAS9JTRnwktTURAGf5Kwktye5M8m2aRUlSZrc2AGf5DDgr4GzgROBFyY5cVqFSZImM8k7+GcAd1bV56rqW8A7gHOnU5YkaVKTfA9+M/CFRdu7gJ/ef6ckFwIXDptfT3L7uBPmdeM+E4CNwL0TvcL60KUP6NOLfaw/67KXMTLswT5+fJz5Jgn4HGCsHjJQdTFw8QTzTEWSHVW1ddZ1TKpLH9CnF/tYf7r0MmkfkyzR7AIeu2j7OOCLE7yeJGmKJgn4/wBOSHJ8kocD5wFXTacsSdKkxl6iqar7k7wc+CBwGHBpVd06tcqmb+bLRFPSpQ/o04t9rD9depmoj1Q9ZNlcktSAR7JKUlMGvCQ11TbgkxyW5D+TfGDYPirJtUnuGG6PnHWNy5HkiCTvTvLpJLcl+Zl57CXJ7yW5NcktSa5M8oPz0keSS5PsSXLLorEla09y0XD6jtuT/OJsqn6oJfp4/fB/61NJ3pvkiEWPzU0fix57ZZJKsnHR2LrsA5buJclvD/XemuTPF42vqJe2AQ+8Arht0fY24LqqOgG4btieB28CrqmqJwNPY9TTXPWSZDPwO8DWqjqJ0Yfy5zE/fVwGnLXf2AFrH07XcR7wlOE5fzOc1mM9uIyH9nEtcFJVnQx8BrgI5rIPkjwWOBO4a9HYeu4DDtBLkucwOivAyVX1FOANw/iKe2kZ8EmOA84B3rpo+Fzg8uH+5cDz17isFUvyI8CzgEsAqupbVXUfc9gLo29sHZ5kA/BIRsdMzEUfVfUR4Cv7DS9V+7nAO6rqm1X1eeBORqf1mLkD9VFVH6qq+4fNf2d0PAvMWR+DNwJ/yPcecLlu+4Ale3kpsL2qvjnss2cYX3EvLQMe+CtG/9APLBo7pqruBhhuj55BXSv1OGAv8LZhuemtSR7FnPVSVbsZvQu5C7gb+J+q+hBz1sd+lqr9QKfw2LzGtY3rJcA/Dffnqo8kzwN2V9Un93torvoYPBH4uSQfS/LhJD81jK+4l3YBn+S5wJ6qunHWtUzBBuAU4M1V9XTgf1m/yxhLGtanzwWOBx4D
PCrJi2db1apZ1ik81pskrwbuB97+4NABdluXfSR5JPBq4I8P9PABxtZlH4tsAI4ETgX+AHhXkjBGL+0CHjgNeF6SnYzOcHl6kiuAe5IcCzDc7ln6JdaNXcCuqvrYsP1uRoE/b738PPD5qtpbVd8G3gM8k/nrY7Glap+7U3gkOR94LvCi+u6BMfPUx+MZvXn45PB7fxzwiSQ/xnz18aBdwHtq5OOMViI2MkYv7QK+qi6qquOqaoHRBxLXV9WLGZ1G4fxht/OB98+oxGWrqi8BX0jypGHoDOC/mL9e7gJOTfLI4Z3IGYw+LJ63PhZbqvargPOSPCLJ8cAJwMdnUN+yJDkL+CPgeVX1jUUPzU0fVXVzVR1dVQvD7/0u4JTh92du+ljkfcDpAEmeCDyc0RklV95LVbX9AZ4NfGC4/2hG33a4Y7g9atb1LbOHLcAO4FPDP/yR89gL8KfAp4FbgL8DHjEvfQBXMvrs4NuMwuOCg9XOaLngs8DtwNmzrv8QfdzJaF33puHnLfPYx36P7wQ2rvc+DvJv8nDgiuF35RPA6eP24qkKJKmpdks0kqQRA16SmjLgJakpA16SmjLgJakpA16SmjLg1V6SxyW5JMm7Z12LtJYMeLVXVZ+rqguWejzJQpJ9SW6adK4khye5Kcm3Fp+TXJqFsS+6La03SZ4KvHa/4ZfUd0+3ejCfraotk9ZQVfuALcM5UaSZMuA1d4bz5H+Y0SHdxzO6UMX/Ac+squdO4fUXgGuAjzI6o98ngbcxOt3C0cCLgFuBdzE64dNhwJ9V1TsnnVuaJpdoNHeq6qs1On3ybwDXVtWWqjq1qh440P5JHp3kLcDTk1y0zGmewOhqWicDTwZ+FfhZ4JXAqxhdUeeLVfW0Gl2l6prJupKmz3fwmmcnMXonfVBV9WXgt1b42p+vqpsBktzK6PJ8leRmYAG4GXhDktcxOqHdv6zw9aVV5zt4zbMTGZ1xbzV8c9H9BxZtPwBsqKrPAD/JKOhfm+RAF5uQZsqA1zx7DPClWUyc5DHAN6rqCkaXIzxlFnVIB+MSjebZB4FLkvx6VX14uILPO4GrGV15/l+BM4HXVNW03+k/FXh9kgcYncv7pVN+fWling9ebSQ5G3hiVb0pyfuAXwFeAHy9qq4+yPMWGK2jnzTFWnYCW6vq3mm9prRSLtGoky3AB5M8DPjy8K2akxitkx/Md4AfneaBTsDDGK3XSzPjEo06eQKj78SfzOiarwALVXXXwZ5UVV/gey9mPLYHD3SaxmtJk3KJRpKacolGkpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpoy4CWpKQNekpr6f/+ukXN21U9bAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sample 100 inverse membrane time constants around the default value of 100\n",
    "samples = norse.torch.functional.lif.LIFParameters().tau_mem_inv + 20 * torch.randn(100)\n",
    "counts, bins = np.histogram(samples.numpy())\n",
    "plt.hist(bins[:-1], bins, weights=counts, histtype='step')\n",
    "plt.xlabel('$\\\\tau_{m}^{-1}$ [1/s]')"
   ]
  },
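  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The lognormal alternative mentioned above can be sketched as follows. This is a minimal sketch under our own assumptions: the helper name `sample_tau_mem_inv_lognormal` is hypothetical, and the target mean of 100 and standard deviation of 20 simply mirror the normal distribution used in the cell above. A desired mean $m$ and standard deviation $s$ are converted into the parameters of the underlying normal distribution via $\\sigma^2 = \\ln(1 + s^2/m^2)$ and $\\mu = \\ln m - \\sigma^2/2$.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "def sample_tau_mem_inv_lognormal(n, mean=100.0, std=20.0):\n",
    "    # Convert the desired mean and standard deviation of the samples into\n",
    "    # the parameters (mu, sigma) of the underlying normal distribution\n",
    "    sigma2 = math.log(1.0 + (std / mean) ** 2)\n",
    "    mu = math.log(mean) - 0.5 * sigma2\n",
    "    # exp() of a normal sample is lognormally distributed and strictly positive\n",
    "    z = torch.randn(n)\n",
    "    return torch.exp(mu + math.sqrt(sigma2) * z)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "tau_mem_inv = sample_tau_mem_inv_lognormal(100)\n",
    "print(bool((tau_mem_inv > 0).all()))  # True\n",
    "```\n",
    "\n",
    "The resulting tensor could be wrapped in `torch.nn.Parameter` in place of the normal samples. Note that positivity is only guaranteed at initialisation: gradient updates can still push individual values below zero during training."
   ]
  },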
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameters of a PyTorch module can be accessed by calling its `parameters` method. In the case of the `LIFRecurrentCell`, these are the input and recurrent weight matrices (the zero diagonal of the recurrent matrix reflects `autapses=False`). Note that both tensors have `requires_grad=True` and that the `LIFParameters` do not appear. This is because none of the tensors in the `LIFParameters` has been registered as a PyTorch `Parameter`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Parameter containing:\n",
       " tensor([[ 0.1287, -0.0703, -0.4601,  ..., -0.0114, -0.0835,  0.1185],\n",
       "         [ 0.0429,  0.0861,  0.2161,  ..., -0.3251, -0.1037, -0.0786],\n",
       "         [-0.3725, -0.0945,  0.1543,  ...,  0.1508, -0.0733,  0.2225],\n",
       "         ...,\n",
       "         [-0.0465, -0.2300, -0.0681,  ..., -0.2809, -0.1205,  0.0732],\n",
       "         [-0.3285, -0.0445,  0.1016,  ..., -0.1084, -0.0621,  0.2362],\n",
       "         [-0.0088,  0.0473, -0.2168,  ...,  0.0307,  0.2217,  0.0065]],\n",
       "        requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([[ 0.0000,  0.0420,  0.1593,  ..., -0.0059,  0.2016, -0.2879],\n",
       "         [-0.0932,  0.0000, -0.0390,  ...,  0.2117, -0.0844, -0.0938],\n",
       "         [-0.0210, -0.1170,  0.0000,  ...,  0.0709, -0.2757,  0.0067],\n",
       "         ...,\n",
       "         [-0.1806, -0.0191,  0.2278,  ...,  0.0000, -0.1108,  0.0404],\n",
       "         [-0.0023,  0.0122, -0.0125,  ...,  0.1097,  0.0000, -0.1081],\n",
       "         [ 0.0801,  0.0539, -0.0436,  ..., -0.1712, -0.2872,  0.0000]],\n",
       "        requires_grad=True)]"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(m.parameters())"
   ]
  },
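  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same distinction can be reproduced without Norse. In this minimal sketch (the classes `PlainAttribute` and `RegisteredAttribute` are hypothetical and serve only as an illustration), a tensor stored as an ordinary attribute is invisible to `named_parameters()`, whereas the same tensor wrapped in `torch.nn.Parameter` is registered automatically when it is assigned to the module:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "class PlainAttribute(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # An ordinary tensor attribute is not registered as a parameter\n",
    "        self.tau_mem_inv = torch.full((3,), 100.0)\n",
    "\n",
    "class RegisteredAttribute(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Assigning a torch.nn.Parameter registers it with the module\n",
    "        self.tau_mem_inv = torch.nn.Parameter(torch.full((3,), 100.0))\n",
    "\n",
    "plain_names = [name for name, _ in PlainAttribute().named_parameters()]\n",
    "registered_names = [name for name, _ in RegisteredAttribute().named_parameters()]\n",
    "print(plain_names)       # []\n",
    "print(registered_names)  # ['tau_mem_inv']\n",
    "```\n",
    "\n",
    "Only registered parameters are handed to an optimiser via `model.parameters()`, which is why the time constants and thresholds have to be registered before they can be learned."
   ]
  },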
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The easiest way to make these values trainable is to define an additional `torch.nn.Module` that explicitly registers the inverse membrane time constant and the threshold as parameters and passes them on to the `LIFRecurrentCell` via its `LIFParameters`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 225,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ParametrizedLIFRecurrentCell(torch.nn.Module):\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super().__init__()\n",
    "        # Per-neuron inverse membrane time constants (around the default of 100)\n",
    "        # and thresholds (around 0.5), registered as trainable parameters\n",
    "        self.tau_mem_inv = torch.nn.Parameter(\n",
    "            norse.torch.functional.lif.LIFParameters().tau_mem_inv\n",
    "            + 20 * torch.randn(hidden_size)\n",
    "        )\n",
    "        self.v_th = torch.nn.Parameter(0.5 + 0.1 * torch.randn(hidden_size))\n",
    "        # The LIFParameters refer to the same Parameter objects, so gradients\n",
    "        # computed inside the cell reach the registered parameters\n",
    "        self.cell = norse.torch.module.LIFRecurrentCell(\n",
    "            input_size=input_size,\n",
    "            hidden_size=hidden_size,\n",
    "            p=norse.torch.functional.lif.LIFParameters(\n",
    "                tau_mem_inv=self.tau_mem_inv,\n",
    "                v_th=self.v_th,\n",
    "                alpha=100,\n",
    "            ),\n",
    "        )\n",
    "\n",
    "    def forward(self, x, s=None):\n",
    "        return self.cell(x, s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 226,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ParametrizedLIFRecurrentCell(\n",
       "  (cell): LIFRecurrentCell(\n",
       "    input_size=200, hidden_size=100, p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=Parameter containing:\n",
       "    tensor([105.7182, 107.4838, 100.0283, 105.1591,  53.3966, 109.9805,  96.5298,\n",
       "             98.7639,  61.6523, 100.5165, 123.9547,  91.8613,  74.4375,  77.2754,\n",
       "             70.4332,  75.0951,  74.7826, 128.1384,  62.5254,  98.4896, 132.6102,\n",
       "            101.3723,  94.8174, 113.6608,  92.8799, 102.4582,  96.0418, 109.7185,\n",
       "             97.0302,  94.2958, 106.0022, 111.0281,  57.8042,  60.2294, 124.0950,\n",
       "            120.5536,  85.0079, 122.0159, 102.3311,  82.8800,  91.6850, 109.0127,\n",
       "             81.5401, 140.7032, 102.0714,  92.9130,  81.3563, 131.6322,  88.3253,\n",
       "             80.4107,  90.0314, 111.8495,  78.5712,  61.6052,  91.2183, 142.2331,\n",
       "            100.1681,  97.8471,  90.6644, 115.2660,  89.4297,  85.5509,  27.3883,\n",
       "             44.9216,  82.8977, 130.8837, 115.9058, 110.5285,  65.6668, 100.6617,\n",
       "             99.4103, 114.3868,  87.7616,  92.7753, 120.3395,  90.8698,  96.0968,\n",
       "             94.1020, 134.0003,  92.5420, 119.3351, 108.7738, 106.1515, 100.8721,\n",
       "             69.7877, 114.0343, 119.7976,  94.8820, 126.8868, 107.7279,  87.2416,\n",
       "            118.0816,  85.2757,  95.5625, 103.5655, 101.1446,  92.8175, 100.2220,\n",
       "            101.5810, 120.5847], requires_grad=True), v_leak=tensor(0.), v_th=Parameter containing:\n",
       "    tensor([0.4763, 0.7102, 0.5170, 0.5944, 0.3752, 0.6186, 0.5279, 0.5828, 0.4731,\n",
       "            0.4420, 0.3426, 0.6271, 0.3732, 0.4154, 0.5121, 0.4326, 0.5450, 0.5674,\n",
       "            0.4952, 0.6068, 0.5628, 0.3578, 0.3267, 0.6845, 0.6106, 0.5249, 0.5779,\n",
       "            0.4979, 0.5589, 0.5879, 0.5523, 0.4543, 0.4563, 0.5968, 0.5322, 0.6053,\n",
       "            0.6176, 0.6008, 0.4561, 0.4256, 0.4695, 0.4654, 0.5881, 0.4071, 0.6236,\n",
       "            0.6023, 0.5724, 0.6621, 0.5386, 0.4352, 0.5922, 0.6512, 0.7498, 0.3829,\n",
       "            0.4641, 0.5914, 0.4768, 0.4183, 0.5869, 0.7052, 0.4787, 0.4283, 0.6367,\n",
       "            0.4590, 0.4615, 0.4977, 0.5333, 0.4441, 0.5708, 0.7478, 0.5310, 0.5089,\n",
       "            0.5937, 0.2963, 0.4295, 0.4411, 0.4158, 0.6090, 0.6406, 0.5169, 0.4072,\n",
       "            0.4794, 0.5371, 0.6913, 0.6281, 0.3540, 0.6166, 0.5970, 0.5496, 0.4816,\n",
       "            0.4926, 0.4962, 0.5401, 0.7173, 0.6444, 0.3962, 0.5445, 0.3929, 0.5025,\n",
       "            0.3769], requires_grad=True), v_reset=tensor(0.), method='super', alpha=tensor(100)), autapses=False, dt=0.001\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 226,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = ParametrizedLIFRecurrentCell(200, 100)\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inverse membrane time constants and the thresholds now also appear as parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Parameter containing:\n",
       " tensor([105.7182, 107.4838, 100.0283, 105.1591,  53.3966, 109.9805,  96.5298,\n",
       "          98.7639,  61.6523, 100.5165, 123.9547,  91.8613,  74.4375,  77.2754,\n",
       "          70.4332,  75.0951,  74.7826, 128.1384,  62.5254,  98.4896, 132.6102,\n",
       "         101.3723,  94.8174, 113.6608,  92.8799, 102.4582,  96.0418, 109.7185,\n",
       "          97.0302,  94.2958, 106.0022, 111.0281,  57.8042,  60.2294, 124.0950,\n",
       "         120.5536,  85.0079, 122.0159, 102.3311,  82.8800,  91.6850, 109.0127,\n",
       "          81.5401, 140.7032, 102.0714,  92.9130,  81.3563, 131.6322,  88.3253,\n",
       "          80.4107,  90.0314, 111.8495,  78.5712,  61.6052,  91.2183, 142.2331,\n",
       "         100.1681,  97.8471,  90.6644, 115.2660,  89.4297,  85.5509,  27.3883,\n",
       "          44.9216,  82.8977, 130.8837, 115.9058, 110.5285,  65.6668, 100.6617,\n",
       "          99.4103, 114.3868,  87.7616,  92.7753, 120.3395,  90.8698,  96.0968,\n",
       "          94.1020, 134.0003,  92.5420, 119.3351, 108.7738, 106.1515, 100.8721,\n",
       "          69.7877, 114.0343, 119.7976,  94.8820, 126.8868, 107.7279,  87.2416,\n",
       "         118.0816,  85.2757,  95.5625, 103.5655, 101.1446,  92.8175, 100.2220,\n",
       "         101.5810, 120.5847], requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([0.4763, 0.7102, 0.5170, 0.5944, 0.3752, 0.6186, 0.5279, 0.5828, 0.4731,\n",
       "         0.4420, 0.3426, 0.6271, 0.3732, 0.4154, 0.5121, 0.4326, 0.5450, 0.5674,\n",
       "         0.4952, 0.6068, 0.5628, 0.3578, 0.3267, 0.6845, 0.6106, 0.5249, 0.5779,\n",
       "         0.4979, 0.5589, 0.5879, 0.5523, 0.4543, 0.4563, 0.5968, 0.5322, 0.6053,\n",
       "         0.6176, 0.6008, 0.4561, 0.4256, 0.4695, 0.4654, 0.5881, 0.4071, 0.6236,\n",
       "         0.6023, 0.5724, 0.6621, 0.5386, 0.4352, 0.5922, 0.6512, 0.7498, 0.3829,\n",
       "         0.4641, 0.5914, 0.4768, 0.4183, 0.5869, 0.7052, 0.4787, 0.4283, 0.6367,\n",
       "         0.4590, 0.4615, 0.4977, 0.5333, 0.4441, 0.5708, 0.7478, 0.5310, 0.5089,\n",
       "         0.5937, 0.2963, 0.4295, 0.4411, 0.4158, 0.6090, 0.6406, 0.5169, 0.4072,\n",
       "         0.4794, 0.5371, 0.6913, 0.6281, 0.3540, 0.6166, 0.5970, 0.5496, 0.4816,\n",
       "         0.4926, 0.4962, 0.5401, 0.7173, 0.6444, 0.3962, 0.5445, 0.3929, 0.5025,\n",
       "         0.3769], requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([[ 0.2647,  0.0369, -0.1230,  ..., -0.0659,  0.0242,  0.1555],\n",
       "         [-0.2033,  0.0797,  0.0882,  ...,  0.2867,  0.1669, -0.0143],\n",
       "         [ 0.1000,  0.0366, -0.3470,  ...,  0.0849, -0.0559,  0.0878],\n",
       "         ...,\n",
       "         [ 0.0844, -0.1082, -0.0416,  ...,  0.1089,  0.0370, -0.1723],\n",
       "         [-0.1273, -0.2060, -0.0552,  ...,  0.3023, -0.1878, -0.1071],\n",
       "         [-0.0659, -0.0088,  0.0081,  ..., -0.2171, -0.1173, -0.2679]],\n",
       "        requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([[ 0.0000e+00,  4.2299e-02,  4.7312e-02,  ..., -1.1504e-02,\n",
       "          -1.9836e-01, -1.2699e-01],\n",
       "         [-1.5662e-01,  0.0000e+00, -3.6343e-02,  ...,  5.3085e-02,\n",
       "           9.1224e-03, -1.6341e-01],\n",
       "         [-3.6074e-03,  3.5989e-02,  0.0000e+00,  ..., -5.7264e-02,\n",
       "           9.8241e-02, -1.4665e-01],\n",
       "         ...,\n",
       "         [-3.4692e-02,  7.1811e-04, -3.9618e-02,  ...,  0.0000e+00,\n",
       "          -1.8011e-01,  1.4050e-01],\n",
       "         [ 2.6199e-02, -2.8831e-02, -1.3311e-01,  ..., -7.4930e-02,\n",
       "           0.0000e+00, -2.1475e-02],\n",
       "         [ 2.4574e-01, -3.2047e-02, -1.4999e-04,  ..., -1.6092e-01,\n",
       "          -3.9056e-02,  0.0000e+00]], requires_grad=True)]"
      ]
     },
     "execution_count": 227,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(m.parameters())"
   ]
  },
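  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Registering the values is only useful if gradients actually reach them. The following minimal sketch uses a hand-written leaky integrator (explicit Euler steps of $\\dot v = \\tau_m^{-1}(-v + i)$ without threshold or reset, so it is a simplification and not Norse's LIF implementation) driven by a hypothetical constant input current. It shows that a per-neuron `tau_mem_inv` of shape `(hidden,)` broadcasts against a state of shape `(batch, hidden)` and receives a gradient through the simulated dynamics:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "dt = 1e-3\n",
    "batch, hidden = 2, 4\n",
    "# One inverse membrane time constant per neuron, shared across the batch\n",
    "tau_mem_inv = torch.nn.Parameter(torch.full((hidden,), 100.0))\n",
    "i_in = torch.ones(batch, hidden)  # hypothetical constant input current\n",
    "v = torch.zeros(batch, hidden)\n",
    "\n",
    "for _ in range(10):\n",
    "    # Explicit Euler step; tau_mem_inv broadcasts over the batch dimension\n",
    "    v = v + dt * tau_mem_inv * (-v + i_in)\n",
    "\n",
    "loss = v.sum()\n",
    "loss.backward()\n",
    "print(tau_mem_inv.grad)  # one entry per neuron, summed over the batch\n",
    "```"
   ]
  },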
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Toy Example: Training a small recurrent SNN on MNIST\n",
    "\n",
    "In order to demonstrate training on an example task, we turn to MNIST. The training and testing code here is adapted from the training notebook on MNIST and serves only as an illustration. We do not necessarily expect a great benefit from optimising the time constants and thresholds in this example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torchvision\n",
    "\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "transform = torchvision.transforms.Compose(\n",
    "    [\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        torchvision.transforms.Normalize((0.1307,), (0.3081,)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_data = torchvision.datasets.MNIST(\n",
    "    root=\".\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transform,\n",
    ")\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_data,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    torchvision.datasets.MNIST(\n",
    "        root=\".\",\n",
    "        train=False,\n",
    "        transform=transform,\n",
    "    ),\n",
    "    batch_size=BATCH_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Definition\n",
    "\n",
    "For simplicity, we consider a network with a single hidden layer of recurrently connected neurons, followed by a linear readout onto leaky integrator (`LICell`) output neurons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "from norse.torch import LICell\n",
    "\n",
    "class SNN(torch.nn.Module):\n",
    "    def __init__(self, input_features, hidden_features, output_features, recurrent_cell):\n",
    "        super().__init__()\n",
    "        self.cell = recurrent_cell\n",
    "        self.fc_out = torch.nn.Linear(hidden_features, output_features, bias=False)\n",
    "        self.out = LICell()\n",
    "        self.input_features = input_features\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x has shape (time, batch, channels, height, width)\n",
    "        seq_length = x.shape[0]\n",
    "        s1 = so = None  # neuron states are initialised by the cells on the first step\n",
    "        voltages = []\n",
    "\n",
    "        for ts in range(seq_length):\n",
    "            # Flatten each image into a vector of input features\n",
    "            z = x[ts].reshape(-1, self.input_features)\n",
    "            z, s1 = self.cell(z, s1)\n",
    "            z = self.fc_out(z)\n",
    "            vo, so = self.out(z, so)\n",
    "            voltages.append(vo)\n",
    "\n",
    "        # Output voltages with shape (time, batch, output_features)\n",
    "        return torch.stack(voltages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(torch.nn.Module):\n",
    "    def __init__(self, encoder, snn, decoder):\n",
    "        super(Model, self).__init__()\n",
    "        self.encoder = encoder\n",
    "        self.snn = snn\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.encoder(x)\n",
    "        x = self.snn(x)\n",
    "        log_p_y = self.decoder(x)\n",
    "        return log_p_y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training and Test loop\n",
    "\n",
    "The training and test loops below are independent of the model and dataset we use. In practice, however, you should probably use a more sophisticated version, since these loops do not take care of parameter checkpointing or other concerns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm, trange\n",
    "\n",
    "def train(model, device, train_loader, optimizer, epoch, max_epochs):\n",
    "    model.train()\n",
    "    losses = []\n",
    "\n",
    "    for (data, target) in tqdm(train_loader, leave=False):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = torch.nn.functional.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item())\n",
    "\n",
    "    mean_loss = np.mean(losses)\n",
    "    return losses, mean_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader, epoch):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += torch.nn.functional.nll_loss(\n",
    "                output, target, reduction=\"sum\"\n",
    "            ).item()  # sum up batch loss\n",
    "            pred = output.argmax(\n",
    "                dim=1, keepdim=True\n",
    "            )  # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    accuracy = 100.0 * correct / len(test_loader.dataset)\n",
    "\n",
    "    return test_loss, accuracy"
   ]
  },
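  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an example of what a more complete loop might add, the following is a minimal checkpointing sketch. The helpers `save_checkpoint` and `load_checkpoint` and the stand-in module `TinyCell` are hypothetical, and an in-memory buffer replaces a file so that the example is self-contained. Because the time constants and thresholds are registered parameters, they are saved and restored by `state_dict()` together with any weights:\n",
    "\n",
    "```python\n",
    "import io\n",
    "import torch\n",
    "\n",
    "class TinyCell(torch.nn.Module):\n",
    "    # Hypothetical stand-in for a module with trainable neuron parameters\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.tau_mem_inv = torch.nn.Parameter(100.0 + 20.0 * torch.randn(5))\n",
    "        self.v_th = torch.nn.Parameter(0.5 + 0.1 * torch.randn(5))\n",
    "\n",
    "def save_checkpoint(module, opt, epoch, f):\n",
    "    # state_dict() contains every registered parameter, including tau_mem_inv and v_th\n",
    "    torch.save({'epoch': epoch, 'model': module.state_dict(), 'optimizer': opt.state_dict()}, f)\n",
    "\n",
    "def load_checkpoint(module, opt, f):\n",
    "    checkpoint = torch.load(f)\n",
    "    module.load_state_dict(checkpoint['model'])\n",
    "    opt.load_state_dict(checkpoint['optimizer'])\n",
    "    return checkpoint['epoch']\n",
    "\n",
    "torch.manual_seed(0)\n",
    "cell = TinyCell()\n",
    "opt = torch.optim.Adam(cell.parameters(), lr=0.002)\n",
    "buffer = io.BytesIO()\n",
    "save_checkpoint(cell, opt, 3, buffer)\n",
    "\n",
    "buffer.seek(0)\n",
    "restored = TinyCell()  # freshly initialised with different random values\n",
    "epoch = load_checkpoint(restored, torch.optim.Adam(restored.parameters(), lr=0.002), buffer)\n",
    "print(epoch)  # 3\n",
    "print(torch.equal(restored.tau_mem_inv, cell.tau_mem_inv))  # True\n",
    "```"
   ]
  },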
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "from norse.torch import ConstantCurrentLIFEncoder\n",
    "\n",
    "def decode(x):\n",
    "    x, _ = torch.max(x, 0)\n",
    "    log_p_y = torch.nn.functional.log_softmax(x, dim=1)\n",
    "    return log_p_y\n",
    "\n",
    "\n",
    "T = 32\n",
    "LR = 0.002\n",
    "INPUT_FEATURES = 28*28\n",
    "HIDDEN_FEATURES = 100\n",
    "OUTPUT_FEATURES = 10\n",
    "EPOCHS = 5\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = torch.device(\"cuda\")\n",
    "else:\n",
    "    DEVICE = torch.device(\"cpu\")"
   ]
  },
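  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the decoding step concrete, here is a small standalone check on random voltages. It repeats the `decode` function from the cell above so that it runs on its own; the shapes (32 time steps, a batch of 4, 10 classes) are illustrative assumptions. Taking the maximum over time yields one score per class, and `log_softmax` turns these scores into log-probabilities, as expected by `nll_loss`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def decode(x):\n",
    "    # x: output voltages with shape (time, batch, classes)\n",
    "    x, _ = torch.max(x, 0)  # maximum voltage over time -> (batch, classes)\n",
    "    return torch.nn.functional.log_softmax(x, dim=1)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "voltages = torch.randn(32, 4, 10)\n",
    "log_p_y = decode(voltages)\n",
    "print(log_p_y.shape)  # torch.Size([4, 10])\n",
    "print(log_p_y.exp().sum(dim=1))  # each row sums to 1 up to rounding\n",
    "```"
   ]
  },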
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_training(model, optimizer, epochs = EPOCHS):\n",
    "    training_losses = []\n",
    "    mean_losses = []\n",
    "    test_losses = []\n",
    "    accuracies = []\n",
    "\n",
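    "    # Anomaly detection helps to track down NaN gradients but slows training\n",
    "    # down considerably; enable it only for debugging\n",
    "    # torch.autograd.set_detect_anomaly(True)\n",
    "\n",
    "    for epoch in trange(epochs):\n",
    "        training_loss, mean_loss = train(model, DEVICE, train_loader, optimizer, epoch, max_epochs=epochs)\n",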
    "        test_loss, accuracy = test(model, DEVICE, test_loader, epoch)\n",
    "        training_losses += training_loss\n",
    "        mean_losses.append(mean_loss)\n",
    "        test_losses.append(test_loss)\n",
    "        accuracies.append(accuracy)\n",
    "\n",
    "    print(f\"final accuracy: {accuracies[-1]}\")\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Evaluation\n",
    "\n",
    "With all of this boilerplate out of the way, we can define our final model and train it for a number of epochs (`EPOCHS`, set to 5 above). If you are impatient, you can reduce this number to 1 or 2. With the default values, the test accuracy should be above 95%, which is nothing to write home about."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Model(\n",
    "    encoder=ConstantCurrentLIFEncoder(\n",
    "      seq_length=T,\n",
    "    ),\n",
    "    snn=SNN(\n",
    "      input_features=INPUT_FEATURES,\n",
    "      hidden_features=HIDDEN_FEATURES,\n",
    "      output_features=OUTPUT_FEATURES,\n",
    "      recurrent_cell=ParametrizedLIFRecurrentCell(\n",
    "            input_size=INPUT_FEATURES,\n",
    "            hidden_size=HIDDEN_FEATURES\n",
    "      )\n",
    "    ),\n",
    "    decoder=decode\n",
    ").to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LR)"
   ]
  },
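  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A practical consideration when training these parameters with Adam: its per-step update is typically on the order of the learning rate, regardless of a parameter's magnitude. With `LR = 0.002`, the inverse membrane time constants (around 100) can therefore change only very little relative to their size, while the thresholds (around 0.5) and the weights change comparatively more. If that is not intended, PyTorch parameter groups allow a separate learning rate per group. The sketch below uses a hypothetical stand-in module, and the learning rate of 1.0 for `tau_mem_inv` is an arbitrary illustrative value, not a tuned recommendation:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "class StandInCell(torch.nn.Module):\n",
    "    # Hypothetical stand-in with the same kinds of parameters as the cell above\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.tau_mem_inv = torch.nn.Parameter(torch.full((3,), 100.0))\n",
    "        self.v_th = torch.nn.Parameter(torch.full((3,), 0.5))\n",
    "        self.weight = torch.nn.Parameter(torch.zeros(3, 3))\n",
    "\n",
    "cell = StandInCell()\n",
    "other_params = [p for name, p in cell.named_parameters() if name != 'tau_mem_inv']\n",
    "opt = torch.optim.Adam(\n",
    "    [\n",
    "        {'params': [cell.tau_mem_inv], 'lr': 1.0},  # larger steps for values around 100\n",
    "        {'params': other_params},  # uses the default lr given below\n",
    "    ],\n",
    "    lr=0.002,\n",
    ")\n",
    "print([group['lr'] for group in opt.param_groups])  # [1.0, 0.002]\n",
    "```\n",
    "\n",
    "For the model above, the corresponding name reported by `model.named_parameters()` should be `snn.cell.tau_mem_inv`."
   ]
  },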
  {
   "cell_type": "code",
   "execution_count": 229,
   "metadata": {},
   "outputs": [],
   "source": [
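    "# Copy the values: on the CPU, .numpy() shares memory with the parameter,\n",
    "# which the optimiser updates in place\n",
    "tau_mem_inv_before = model.snn.cell.cell.p.tau_mem_inv.detach().cpu().numpy().copy()\n",
    "v_th_before = model.snn.cell.cell.p.v_th.detach().cpu().numpy().copy()"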
   ]
  },
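  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A subtlety when taking such snapshots: on the CPU, `.cpu()` returns the very same tensor, and `.detach().numpy()` returns an array that shares memory with it. Because optimisers update parameters in place, such an array silently follows the training unless it is copied. A minimal torch-only check (the two-element parameter and the plain SGD step are illustrative choices):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "p = torch.nn.Parameter(torch.tensor([100.0, 50.0]))\n",
    "aliased = p.cpu().detach().numpy()        # shares memory with p on the CPU\n",
    "copied = p.cpu().detach().numpy().copy()  # independent snapshot\n",
    "\n",
    "opt = torch.optim.SGD([p], lr=1.0)\n",
    "p.sum().backward()  # gradient of 1 for every entry\n",
    "opt.step()          # in-place update: p <- p - lr * grad\n",
    "\n",
    "print(aliased.tolist())  # [99.0, 49.0]\n",
    "print(copied.tolist())   # [100.0, 50.0]\n",
    "```"
   ]
  },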
  {
   "cell_type": "code",
   "execution_count": 230,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "final accuracy: 95.94\n"
     ]
    }
   ],
   "source": [
    "model_after = run_training(model, optimizer, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "metadata": {},
   "outputs": [],
   "source": [
    "tau_mem_inv_after = model_after.snn.cell.cell.p.tau_mem_inv.cpu().detach().numpy()\n",
    "v_th_after = model_after.snn.cell.cell.p.v_th.cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can plot the distribution of inverse membrane time constants before and after training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhgAAAFECAYAAABlIEcBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXCElEQVR4nO3df5CU9Z3g8fcn4AoqggomBCTD7uqKIgFvYhExliRn0LjxR6pylbC5mDNZ9GrXRHLWHa6V3bGuKsvVetGYH5syJ2rKxJV4C0kd+YG/EDBoDgIRDBo0zuEgBiQnSgQj+rk/uuEIzMD8+M480zPvV5U1008/3f1Jfyvwpp/upyMzkSRJKukdVQ8gSZIGHgNDkiQVZ2BIkqTiDAxJklScgSFJkoozMCRJUnFD+/LBRo8enU1NTX35kJIkqZesWbPm5cwc0951fRoYTU1NrF69ui8fUpIk9ZKI+D8dXechEkmSVJyBIUmSijMwJElScX36HgxJkhrFm2++SVtbG3v27Kl6lMoNGzaM8ePHc9RRR3X6NgaGJEntaGtrY8SIETQ1NRERVY9Tmcxkx44dtLW1MXHixE7fzkMkkiS1Y8+ePZx00kmDOi4AIoKTTjqpy6/kGBiSJHVgsMfFPt15HgwMSZL6qdbWViZPntzp/Z9++mmmTp3KtGnTeO6553pxsiPzPRiSJHXCjPkPs+WV3cXub9yo4Tw274PF7g9g8eLFXHbZZdx0002d2j8zyUze8Y7yrzcYGJIkdcKWV3bTOv+SYvfXNG9Jp/bbu3cvV155JWvXruW0007jO9/5Dhs3buSLX/wiu3btYvTo0dx1112sXbuWW2+9lSFDhrB8+XIeeeQRvvKVr7BgwQIAPve5z3HdddfR2trKxRdfzMyZM1m1ahWLFy9m4cKFLFy4kDfeeIMrrrii04FyOB4ikSSpH3vmmWeYM2cOTz75JMcffzzf+MY3uPbaa7n//vtZs2YNV111FTfeeCMf+chHuOaaa5g7dy6PPPIIa9as4c477+SJJ57g8ccf59vf/jZr167df5+f/vSnWbt2Lc888wybNm3i5z//OevWrWPNmjUsX768x3P7CobUj82Y/zD37f5rxsfLVY/SdSMnwNz1VU8hNbxTTjmFGTNmAPCpT32KL3/5y2zYsIELL7wQgLfeeouxY8cecruVK1dyxRVXcOyxxwLwsY99jBUrVnDppZfynve8h+nTpwOwdOlSli5dyrRp0wDYtWsXmzZt4vzzz+/R3AaG1I9teWU344e9DC07qx6lS5rmLaGV2VWPIQ0IB3+CY8SIEZx55pmsWrXqsLfLzA6v2xcd+/a74YYbuPrqq3s26EE8RCJJUj+2efPm/TFx7733Mn36dLZv375/25tvvslTTz11yO3OP/98Fi9ezOuvv87vf/97Fi1axAc+8IFD9ps1axYLFixg165dAGzZsoVt27b1eG5fwZAkqR+bNGkSd999N1dffTWnnnoq1157LbNmzeLzn/88O3fuZO/evVx33XWceeaZf3S7s88+m8985jOcc845QO1NntOmTaO1tfWP9vvwhz/Mxo0bef/73w/Acccdxz333MPJJ5/co7njcC+hlNbc3JyrV6/us8eTGl3TvCW0DpvdmIdIGnBu6UAbN25k0qRJ+y83wsdUe9PBzwdARKzJzOb29vcVDEmSOqGRYqA/8D0YkiSpOANDkiQVZ2BIkqTiDAxJklScgSFJkoozMCRJajDf//73mTRpEjNnzmTZsmX87Gc/q3qkQ/gxVUmSOuOWs2Dn5nL314Pv67njjjv45je/ycyZM2lpaeG4447j3HPP7fTt9+7dy9ChvZsABoYkSZ2xc3PZk8e1jOzUbpdffjkvvPACe/bs4Qtf+AIvvfQSK1eu5Pnnn2fKlCmsWLGCIUOGcM899/C1r32N008/nWuuuYbNm2sxdOuttzJjxgxaWlp48cUXaW1tZfTo0Xzve98r
97+lHQaGJEn92IIFCzjxxBPZvXs373vf+3j00Ud5+OGHufnmm2lubt7/Csb1118PwOzZs5k7dy7nnXcemzdvZtasWWzcuBGANWvWsHLlSoYPH97rcxsYkiT1Y7fddhuLFi0C4IUXXmDTpk2H3f/BBx/kV7/61f7Lr776Kq+99hoAl156aZ/EBRgYkiT1W8uWLePBBx9k1apVHHPMMVxwwQXs2bPnsLd5++23WbVqVbshceDXtPc2P0UiSVI/tXPnTk444QSOOeYYnn76aR5//PFD9hkxYsT+Vyig9u2oX//61/dfXrduXV+MeggDQ5Kkfuqiiy5i7969TJkyhS996UtMnz79kH0++tGPsmjRIqZOncqKFSu47bbbWL16NVOmTOGMM87gW9/6VgWTe4hEkqTOGTmh05/86PT9HcHRRx/Nj3/840O2L1u2bP/vp512Gk8++eQfXX/fffcdcpuWlpYuj9gTRwyMiDgF+A7wLuBt4PbM/GpEnAjcBzQBrcC/y8z/23ujSpJUoW6es2Kw6swhkr3Af8rMScB04G8i4gxgHvBQZp4KPFS/LEmSdOTAyMytmfmL+u+vARuBccBlwN313e4GLu+lGSVJUoPp0ps8I6IJmAY8AbwzM7dCLUKAk4tPJ0lShTKz6hH6he48D50OjIg4DvifwHWZ+WoXbjcnIlZHxOrt27d3eUBJkqowbNgwduzYMegjIzPZsWMHw4YN69LtOvUpkog4ilpcfDcz/7W++bcRMTYzt0bEWGBbB4PdDtwO0NzcPLhXSZLUMMaPH09bWxv+47gWW+PHj+/SbTrzKZIA7gA2ZuZXDrjqh8CVwPz6zx906ZElSerHjjrqKCZOnFj1GA2rM69gzAD+PbA+ItbVt/0dtbBYGBGfBTYDH++VCSVJUsM5YmBk5kogOrj6Q2XHkSRJA4GnCpckScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqbgjBkZELIiIbRGx4YBtLRGxJSLW1f/7SO+OKUmSGklnXsG4C7ione23ZObU+n8/KjuWJElqZEcMjMxcDvyuD2aRJEkDRE/eg/G3EfFk/RDKCcUmkiRJDa+7gfHPwJ8BU4GtwH/vaMeImBMRqyNi9fbt27v5cJIkqZF0KzAy87eZ+VZmvg18GzjnMPvenpnNmdk8ZsyY7s4pSZIaSLcCIyLGHnDxCmBDR/tKkqTBZ+iRdoiIe4ELgNER0Qb8A3BBREwFEmgFru69ESVJUqM5YmBk5ifb2XxHL8wiSZIGCM/kKUmSijMwJElScQaGJEkqzsCQJEnFGRiSJKk4A0OSJBV3xI+pSu265SzYubnqKbqsLUdz3hu3VT1Gp40bNRz2VD1F140bNZy23aMZ3zKy6lG6ZuQEmLu+6imkAcHAUPfs3AwtO6ueokua5i2hddhsWudfUvUoXdNS9QBd99i8DwLPVT1GlzTNW0Irs6seQxowPEQiSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFTe06gHUeGbMf5jHgKZ5S6oepUvGjRoOR0+A
lpFVj9I1IydUPYEkdZmBoS7b8spuGAat8y+pepRuWF/1AJI0KHiIRJIkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFXfEwIiIBRGxLSI2HLDtxIh4ICI21X+e0LtjSpKkRtKZVzDuAi46aNs84KHMPBV4qH5ZkiQJ6ERgZOZy4HcHbb4MuLv++93A5WXHkiRJjay778F4Z2ZuBaj/PLncSJIkqdEN7e0HiIg5wByACRMm9O6D3XIW7Nzcu49R2sgJMHd91VNIg964UcNp2z2a8S0jqx6l6/xzRP1QdwPjtxExNjO3RsRYYFtHO2bm7cDtAM3NzdnNx+ucnZuhZWevPkRxjfiHmTQAPTbvg8BzVY/RZU3zltDK7KrHkA7R3UMkPwSurP9+JfCDMuNIkqSBoDMfU70XWAX8RUS0RcRngfnAhRGxCbiwflmSJAnoxCGSzPxkB1d9qPAskiRpgPBMnpIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSqu17+LpK/MmP8wj1E7bW6jGDdqOI9VPYSkhtao36GylTGMbXm26jHUiwZMYGx5ZTcMg9b5l1Q9Sqc1zVsCw6qeQlIja9TvUBnbYEGkrvMQiSRJKs7AkCRJxRkYkiSpOANDkiQVZ2BIkqTiDAxJklScgSFJkoozMCRJUnEGhiRJKs7AkCRJxRkYkiSpOANDkiQVZ2BIkqTiDAxJklScgSFJkoozMCRJUnEGhiRJKs7AkCRJxRkYkiSpOANDkiQVZ2BIkqTiDAxJklTc0KoHGMzGjRoOe6Bp3pKqR+mSfXNLktQRA6NCj837ILRA6/xLqh6l61qqHkCS1J95iESSJBVnYEiSpOIMDEmSVJyBIUmSijMwJElScQaGJEkqzsCQJEnFGRiSJKk4A0OSJBXnmTyrNnICtIyseoquGzmh6gkkSf2YgVG1ueurnkCSpOI8RCJJkoozMCRJUnEGhiRJKs7AkCRJxRkYkiSpOANDkiQV16OPqUZEK/Aa8BawNzObSwwlSZIaW4nzYMzMzJcL3I8kSRogPEQiSZKK62lgJLA0ItZExJwSA0mSpMbX00MkMzLzxYg4GXggIp7OzOUH7lAPjzkAEyb4/RWSJA0GPXoFIzNfrP/cBiwCzmlnn9szszkzm8eMGdOTh5MkSQ2i24EREcdGxIh9vwMfBjaUGkySJDWunhwieSewKCL23c/3MvMnRaaSJEkNrduBkZm/Ad5bcBZJkjRA+DFVSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqbihVQ8gSVIjmDH/Ye7b/deMj5erHqVLtjKGsS3P9vnjGhiSJHXClld2M37Yy9Cys+pRumRsy8hKHtdDJJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKKMzAkSVJxBoYkSSrOwJAkScUZGJIkqTgDQ5IkFeepwiVJfa72/RjVnMK6u1qHASMnVD1GwzAwJEl9roov31Lf8hCJJEkqzsCQJEnFGRiSJKk4A0OSJBVnYEiSpOIMDEmSVFyPAiMiLoqIZyLi2YiYV2ooSZLU2LodGBExBPgGcDFwBvDJiDij1GCSJKlx9eQVjHOAZzPzN5n5B+BfgMvKjCVJkhpZTwJjHPDCAZfb6tskSdIg15NThUc72/KQnSLmAHPqF3dFxDM9eMwjD3RTAIwGXu6tx1GXuR79j2vSv7ge/cvAW4+b2vsru4j3dHRF
TwKjDTjlgMvjgRcP3ikzbwdu78HjdFlErM7M5r58THXM9eh/XJP+xfXoX1yPMnpyiOR/A6dGxMSI+BPgE8APy4wlSZIaWbdfwcjMvRHxt8BPgSHAgsx8qthkkiSpYfXo69oz80fAjwrNUlKfHpLREbke/Y9r0r+4Hv2L61FAZB7yvkxJkqQe8VThkiSpuAERGBExKiLuj4inI2JjRLw/Ik6MiAciYlP95wlVzzlYRMTciHgqIjZExL0RMcz16DsRsSAitkXEhgO2dfj8R8QN9dP9PxMRs6qZeuDqYD3+qf7n1ZMRsSgiRh1wnevRi9pbjwOuuz4iMiJGH7DN9eimAREYwFeBn2Tm6cB7gY3APOChzDwVeKh+Wb0sIsYBnweaM3MytTcAfwLXoy/dBVx00LZ2n//66f0/AZxZv803618DoHLu4tD1eACYnJlTgF8DN4Dr0Ufu4tD1ICJOAS4ENh+wzfXogYYPjIg4HjgfuAMgM/+Qma9QO2353fXd7gYur2K+QWooMDwihgLHUDs/iuvRRzJzOfC7gzZ39PxfBvxLZr6Rmc8Dz1L7GgAV0t56ZObSzNxbv/g4tfMIgevR6zr4/wfALcB/5o9PGOl69EDDBwbwp8B24M6IWBsR/yMijgXemZlbAeo/T65yyMEiM7cAN1P7V8BWYGdmLsX1qFpHz7+n/K/eVcCP67+7HhWIiEuBLZn5y4Oucj16YCAExlDgbOCfM3Ma8Ht8+b0y9WP7lwETgXcDx0bEp6qdSofRqVP+q3dExI3AXuC7+za1s5vr0Ysi4hjgRuDv27u6nW2uRycNhMBoA9oy84n65fupBcdvI2IsQP3ntormG2z+LfB8Zm7PzDeBfwXOxfWoWkfPf6dO+a/yIuJK4C+Bv8r/f74A16Pv/Rm1fxD9MiJaqT3nv4iId+F69EjDB0ZmvgS8EBF/Ud/0IeBX1E5bfmV925XADyoYbzDaDEyPiGMiIqitx0Zcj6p19Pz/EPhERBwdEROBU4GfVzDfoBIRFwH/Bbg0M18/4CrXo49l5vrMPDkzmzKziVpUnF3/u8X16IEencmzH7kW+G79O1F+A/wHavG0MCI+S+0vvY9XON+gkZlPRMT9wC+ovfS7ltpZ8Y7D9egTEXEvcAEwOiLagH8A5tPO85+ZT0XEQmpRvhf4m8x8q5LBB6gO1uMG4GjggVqH83hmXuN69L721iMz72hvX9ejZzyTpyRJKq7hD5FIkqT+x8CQJEnFGRiSJKk4A0OSJBVnYEiSpOIMDEmSVJyBIUmSijMwJHVJRPxpRNxRP6GaJLXLwJDUJZn5m8z8bEfXR0RTROyOiHU9fayIGB4R6yLiDxExuqf3J6nvDJRThUsqLCLOAv7xoM1XZWZnvqjuucyc2tMZMnM3MLX+JVSSGoiBIQ1yEXE88CjwJ9S+VfLXwB7g3Mz8ywL33wT8BFgJTAd+CdwJ3AScDPwV8BSwkNq3VQ4B/mtm3tfTx5ZUHQ+RSINcZr6amdOofUngA5k5NTOnZ+bb7e0fESdFxLeAaRFxQycf5s+BrwJTgNOB2cB5wPXA3wEXAS9m5nszczK1IJHUwHwFQ9I+k6m9knBYmbkDuKaL9/18Zq4HiIingIcyMyNiPdAErAdujoj/BvyvzFzRxfuX1M/4Coakfc4ANvTSfb9xwO9vH3D5bWBoZv4a+DfUQuMfI+Lve2kOSX3EwJC0z7uBl6p44Ih4N/B6Zt4D3AycXcUcksrxEImkfX4K3BERn8nMRyPiXcB9wBLgTOBnwIVAS2aWfqXjLOCfIuJt4E3gPxa+f0l9LDKz6hkk9UMRcTFwWmZ+NSIWAx8DPg7syswlh7ldE7X3UUwuOEsr0JyZL5e6T0m9y0MkkjoyFfhpRBwF7Kh/qmQytfdJHM5bwMiSJ9oCjqL2fg1JDcJDJJI68ufUzokxBdhY39aUmZsPd6PMfAE4pcQA+060VeK+JPUtD5FIkqTiPEQiSZKKMzAkSVJxBoYk
SSrOwJAkScUZGJIkqTgDQ5IkFWdgSJKk4gwMSZJUnIEhSZKK+39wdeCBb3FwYAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 648x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "counts, bins = np.histogram(tau_mem_inv_before)\n",
    "fig, ax = plt.subplots(figsize=(9,5))\n",
    "ax.hist(bins[:-1], bins, weights=counts, histtype='step', label='before')\n",
    "counts, bins = np.histogram(tau_mem_inv_after)\n",
    "ax.hist(bins[:-1], bins, weights=counts, histtype='step', label='after')\n",
    "# tau_mem_inv is an inverse time constant (dt is in seconds), so its unit is 1/s\n",
    "ax.set_xlabel('$\\\\tau_{m}^{-1}$ [1/s]')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we can plot the distribution of membrane threshold voltages before and after optimisation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhgAAAFBCAYAAAA17dayAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXkUlEQVR4nO3df5SW5X3n8fc3YgIqGVAkZcHJuFutaCRgRxdDYsXUoLj+IJueTbKJpsaiu62JZN1T0tR20pzm8EcaPSRxPSQSTd3YmB+QbNSsMUgAHUyYQhAdLLGZjoMmKIZRFCyjV/+YBzoDg3Mzc81zzzPzfp3Dmef+/Z2Lm+f5cF/3c92RUkKSJCmnN5VdgCRJGnkMGJIkKTsDhiRJys6AIUmSsjNgSJKk7AwYkiQpuzHVPNikSZNSQ0NDNQ8pSZKGSEtLy/MppRP7WlbVgNHQ0MCGDRuqeUhJkjREIuJfDrfMLhJJkpSdAUOSJGVnwJAkSdlV9R4MSZJqxb59++jo6GDv3r1ll1K6sWPHMm3aNI4++ujC2xgwJEnqQ0dHB+PHj6ehoYGIKLuc0qSU2LlzJx0dHZx88smFt7OLRJKkPuzdu5cTTjhhVIcLgIjghBNOOOIrOQYMSZIOY7SHi/0G0g4GDEmShqm2tjbe8Y53FF5/69atzJw5k1mzZvHUU08NYWX98x4MSZIKmLNkFdt37cm2v6kTxvHw4guy7Q9g5cqVXH755Xz2s58ttH5KiZQSb3pT/usNBgxJkgrYvmsPbUsuyba/hsX3Flqvq6uLq666io0bN3LqqafyjW98g9bWVj71qU+xe/duJk2axB133MHGjRu55ZZbOOqoo1izZg0PPfQQX/ziF1m+fDkA11xzDTfccANtbW1cfPHFzJ07l+bmZlauXMk999zDPffcw6uvvsqCBQsKB5Q3YheJJEnD2JNPPsnChQvZvHkzb33rW/nKV77C9ddfz3e+8x1aWlq4+uqr+cxnPsP8+fO57rrrWLRoEQ899BAtLS18/etf59FHH2X9+vV89atfZePGjQf2eeWVV7Jx40aefPJJtm3bxs9+9jM2bdpES0sLa9asGXTdXsGQdGRuPhM628uuori6elj0WNlVSAN20kknMWfOHAA+8pGP8PnPf54tW7Zw4YUXAvDaa68xZcqUQ7Zbt24dCxYs4NhjjwXg/e9/P2vXruWyyy7j7W9/O7NnzwbggQce4IEHHmDWrFkA7N69m23btnHeeecNqm4DhqQj09kOTZ1lV1FcU13ZFUiDcvA3OMaPH88ZZ5xBc3PzG26XUjrssv2hY/96n/70p7n22msHV+hB7CKRJGkYa29vPxAm7r77bmbPns1zzz13YN6+fft4/PHHD9nuvPPOY+XKlbzyyiu8/PLLrFixgve85z2HrDdv3jyWL1/O7t27Adi+fTs7duwYdN1ewZAkaRibPn06d955J9deey2nnHIK119/PfPmzeMTn/gEnZ2ddHV1ccMNN3DGGWf02u6ss87iYx/7GOeccw7QfZPnrFmzaGtr67Xe+973PlpbWzn33HMBOO6447jrrruYPHnyoOqON7qEkltjY2PasGFD1Y4naQg01dVeF0kt1atho7W1lenTpx+YroWvqQ6lg9sDICJaUkqNfa3vFQxJkgqopTAwHHgPhiRJys6AIUmSsjNgSJKk7AwYkiQpOwOGJEnKzoAhSVKN+fa3v8306dOZO3cuq1ev5pFHHim7pEP4NVVJkorI/RyeQTwn5/bbb+fWW29l7ty5NDU1cdxxx/Gud72r8PZdXV2MGTO0EcCAIUlSEbmfw1PwOTlXXHEFTz/9NHv37uWTn/wkv/71r1m3bh2/+tWvmDFjBmvXruWoo47irrvu4ktf+hKnnXYa1113He3t3WHolltuYc6cOTQ1NfHMM8/Q1tbGpEmT+OY3v5nvd+mDAUOSpGFs+fLlHH/88ezZs4ezzz6bn/70p6xatYov
fOELNDY2HriCceONNwLw4Q9/mEWLFvHud7+b9vZ25s2bR2trKwAtLS2sW7eOcePGDXndBgxJkoaxpUuXsmLFCgCefvpptm3b9obrP/jggzzxxBMHpl988UVeeuklAC677LKqhAswYEiSNGytXr2aBx98kObmZo455hjOP/989u7d+4bbvP766zQ3N/cZJHo+pn2o+S0SSZKGqc7OTiZOnMgxxxzD1q1bWb9+/SHrjB8//sAVCuh+OuqXv/zlA9ObNm2qRqmHMGBIkjRMXXTRRXR1dTFjxgxuuukmZs+efcg6l156KStWrGDmzJmsXbuWpUuXsmHDBmbMmMHpp5/ObbfdVkLldpFIklRMXX3hb34U3l8/3vKWt3D//fcfMn/16tUHXp966qls3ry51/Jvfetbh2zT1NR0xCUOhgFDkqQiBjhmxWhlF4kkScrOgCFJkrIzYEiSdBgppbJLGBYG0g4GDEmS+jB27Fh27tw56kNGSomdO3cyduzYI9rOmzwlSerDtGnT6Ojo4Lnnniu7lNKNHTuWadOmHdE2BgxJkvpw9NFHc/LJJ5ddRs2yi0SSJGVnwJAkSdkZMCRJUnb9BoyIOCkiHoqI1oh4PCI+WZl/fET8OCK2VX5OHPpyJUlSLShyBaML+F8ppenAbOBPI+J0YDHwk5TSKcBPKtOSJEn9B4yU0rMppX+svH4JaAWmApcDd1ZWuxO4YohqlCRJNeaI7sGIiAZgFvAo8LaU0rPQHUKAydmrkyRJNalwwIiI44DvAjeklF48gu0WRsSGiNjgYCWSJI0OhQJGRBxNd7j4vyml71Vm/yYiplSWTwF29LVtSmlZSqkxpdR44okn5qhZkiQNc0W+RRLA7UBrSumLPRb9ALiq8voq4Pv5y5MkSbWoyFDhc4CPAo9FxKbKvL8AlgD3RMTHgXbgj4akQkmSVHP6DRgppXVAHGbxe/OWI0mSRgJH8pQkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2Y0puwBJUg83nwmd7WVXUVxdPSx6rOwqNAwZMCRpOOlsh6bOsqsorqmu7Ao0TNlFIkmSsjNgSJKk7AwYkiQpOwOGJEnKzoAhSZKyM2BIkqTsDBiSJCk7A4YkScrOgCFJkrIzYEiSpOwcKlwaDmrp+RN19WVXMOzNWbKK7bv2DGjbtrHQsPjeQdcwdcI4Hl58waD3Iw2UAUMaDmrt+RN6Q9t37aFtySUD27iJgW/bQ46QIg2GXSSSJCk7A4YkScrOgCFJkrIzYEiSpOwMGJIkKTsDhiRJys6AIUmSsjNgSJKk7AwYkiQpO0fylCQNicEMmZ6Dw6WXy4AhSRoSgxoyPQOHSy+XXSSSJCk7A4YkScrOgCFJkrIzYEiSpOwMGJIkKbt+A0ZELI+IHRGxpce8pojYHhGbKn/mD22ZkiSplhS5gnEHcFEf829OKc2s/Lkvb1mSJKmW9RswUkprgBeqUIskSRohBnMPxp9FxOZKF8rEbBVJkqSaN9CRPP8P8DkgVX7+HXB1XytGxEJgIUB9ff0ADydJA1RXD011VT1k21igaYAb1/k+mcvUCeNKH81zNA9XPqCAkVL6zf7XEfFV4IdvsO4yYBlAY2NjGsjxJGnAFj1W9UM2LL631CGy1W04fLCXHXDKNKAukoiY0mNyAbDlcOtKkqTRp98rGBFxN3A+MCkiOoC/Bs6PiJl0d5G0AdcOXYmSJKnW9BswUkof6mP27UNQiyRJGiEcyVOSJGVnwJAkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2Y0puwBJI8ecJavYvmtPqTVMnTCO
hxdfUGoNw8HUCeNoWHzvkB+nbSyHPc7UCeOG/PgavgwYkrLZvmsPbUsuKbWGanyo1oKqhawmSv871/BkF4kkScrOgCFJkrIzYEiSpOwMGJIkKTsDhiRJys6AIUmSsjNgSJKk7AwYkiQpOwOGJEnKzoAhSZKyM2BIkqTsDBiSJCk7A4YkScrOgCFJkrIzYEiSpOwMGJIkKTsDhiRJys6AIUmSsjNgSJKk7AwYkiQpuzFlFyANiZvPhM72sqsorq6+7AokKSsDhkamznZo6iy7CkkatewikSRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZ9RswImJ5ROyIiC095h0fET+OiG2VnxOHtkxJklRLilzBuAO46KB5i4GfpJROAX5SmZYkSQIKBIyU0hrghYNmXw7cWXl9J3BF3rIkSVItG+g9GG9LKT0LUPk5OV9JkiSp1g35SJ4RsRBYCFBf73DINcuhtyVJR2CgAeM3ETElpfRsREwBdhxuxZTSMmAZQGNjYxrg8VQ2h96WJB2BgXaR/AC4qvL6KuD7ecqRJEkjQZGvqd4NNAO/FxEdEfFxYAlwYURsAy6sTEuSJAEFukhSSh86zKL3Zq5FkiSNEI7kKUmSsjNgSJKk7AwYkiQpOwOGJEnKzoAhSZKyM2BIkqTsDBiSJCm7IX8WiSRV09QJ42hYfG/pNUijnQFD0ojy8OILyi5BEnaRSJKkIWDAkCRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZGTAkSVJ2Y8ouQJKkkWrqhHE0LL639BoeXnxB1Y9rwJAkaYiU8cF+sLICjl0kkiQpOwOGJEnKzoAhSZKyM2BIkqTsDBiSJCk7A4YkScpuUF9TjYg24CXgNaArpdSYoyhJklTbcoyDMTel9HyG/UiSpBHCLhJJkpTdYANGAh6IiJaIWJijIEmSVPsG20UyJ6X0TERMBn4cEVtTSmt6rlAJHgsB6uvrB3k4aXias2QV23ftKbuM0k2dMK7sEiQNE4MKGCmlZyo/d0TECuAcYM1B6ywDlgE0NjamwRxPGq6279pD25JLyi5DkoaNAXeRRMSxETF+/2vgfcCWXIVJkqTaNZgrGG8DVkTE/v18M6X0oyxVSZKkmjbggJFS+mfgnRlrkSRJI4RfU5UkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2RkwJElSdgYMSZKUnQFDkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmSlJ0BQ5IkZWfAkCRJ2Y0puwBJUg2rq4emurKrKK6uHhY9VnYVo4IBQ5I0cLX2YV1LYajG2UUiSZKyM2BIkqTsDBiSJCk7A4YkScrOgCFJkrIzYEiSpOwMGJIkKTsDhiRJys6AIUmSshs5I3nefCZ0tpddRXEOVytJGsFGTsDobIemzrKrKM7haiVJI5hdJJIkKTsDhiRJys6AIUmSsjNgSJKk7AwYkiQpOwOGJEnKzoAhSZKyM2BIkqTsDBiSJCk7A4YkScpu5AwVrlFrzpJVbN+1p9Qapk4YV+rxJRVUV19bj2qo4edWGTBU87bv2kPbkkvKLkNSLai1D+taCkMHsYtEkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGU3qIARERdFxJMR8cuIWJyrKEmSVNsGHDAi4ijgK8DFwOnAhyLi9FyFSZKk2jWYKxjnAL9MKf1zSulfgX8ALs9TliRJqmWDCRhTgad7THdU5kmSpFEuUkoD2zDij4B5KaVrKtMfBc5JKV1/0HoLgYWVyd8DnnyD3U4Cnh9QQSOT7dGb7dGb
7dGb7dGb7dGb7dFbrvZ4e0rpxL4WDGao8A7gpB7T04BnDl4ppbQMWFZkhxGxIaXUOIiaRhTbozfbozfbozfbozfbozfbo7dqtMdgukh+DpwSESdHxJuBDwI/yFOWJEmqZQO+gpFS6oqIPwP+P3AUsDyl9Hi2yiRJUs0a1NNUU0r3AfdlqgUKdqWMIrZHb7ZHb7ZHb7ZHb7ZHb7ZHb0PeHgO+yVOSJOlwHCpckiRlV0rA6G+I8ei2tLJ8c0ScVUad1VKgPU6LiOaIeDUibiyjxmoq0B7/vXJebI6IRyLinWXUWS0F2uPySltsiogNEfHuMuqslqKPKIiIsyPitYj4QDXrq7YC58f5EdFZOT82RcRflVFntRQ5PyptsikiHo+In1a7xmoqcH787x7nxpbKv5njsxw8pVTVP3TfEPoU8B+BNwO/AE4/aJ35wP1AALOBR6td5zBrj8nA2cDfAjeWXfMwaI93ARMrry/2/OA4/r27cwawtey6y2yPHuutovsesQ+UXXfJ58f5wA/LrnUYtccE4AmgvjI9uey6y2yPg9a/FFiV6/hlXMEoMsT45cA3Urf1wISImFLtQquk3/ZIKe1IKf0c2FdGgVVWpD0eSSn9tjK5nu4xWEaqIu2xO1XeHYBjgZF8Y1XRRxRcD3wX2FHN4krgIxt6K9IeHwa+l1Jqh+731yrXWE1Hen58CLg718HLCBhFhhgfTcOQj6bftYgjbY+P0321a6Qq1B4RsSAitgL3AldXqbYy9NseETEVWADcVsW6ylL038u5EfGLiLg/Is6oTmmlKNIepwITI2J1RLRExJVVq676Cr+fRsQxwEV0B/MsBvU11QGKPuYd/D+uIuuMFKPpdy2icHtExFy6A8ZIvuegUHuklFYAKyLiPOBzwB8OdWElKdIetwB/nlJ6LaKv1UeUIu3xj3QP57w7IuYDK4FThrqwkhRpjzHA7wPvBcYBzRGxPqX0T0NdXAmO5PPlUuDhlNILuQ5eRsAoMsR4oWHIR4jR9LsWUag9ImIG8DXg4pTSzirVVoYjOj9SSmsi4j9FxKSU0kh87kKR9mgE/qESLiYB8yOiK6W0sioVVle/7ZFSerHH6/si4tZRfn50AM+nlF4GXo6INcA7gZEYMI7k/eODZOwegXK6SIoMMf4D4MrKt0lmA50ppWerXWiVOOR6b/22R0TUA98DPjpC/9fRU5H2+N2ofJpWvnH1ZmCkhq5+2yOldHJKqSGl1AB8B/ifIzRcQLHz43d6nB/n0P2+P2rPD+D7wHsiYkylW+A/A61VrrNaCn2+REQd8Ad0t002Vb+CkQ4zxHhEXFdZfhvdd37PB34JvAL8cbXrrJYi7RERvwNsAN4KvB4RN9B9J/CLh9tvrSp4fvwVcAJwa+V9syuN0IcYFWyP/0p3IN8H7AH+W4+bPkeUgu0xahRsjw8A/yMiuug+Pz44ms+PlFJrRPwI2Ay8DnwtpbSlvKqHzhH8e1kAPFC5qpONI3lKkqTsHMlTkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmDEhENEbEnIjYNwb7HRcSmiPjXiJiUe/+Sho4BQ1IOT6WUZubeaUppT2W/o/n5PFJNMmBIo1xE/H5EPNRj+h0R0TzIfa6sPAr78YhY2MfyhojY0mP6xohoGswxJQ0vZTxNVdLw0gqc2mP6b4CbImJiSum3A9zn1SmlFyJiHPDziPjuCH/qraSDeAVDGuVSSq8AeyNiQuVprBNTSg8CN+9fJyK+doS7/URE/AJYT/fjok/JVrCkmmDAkATwBHAa8DngLyPiIuC0StfFMcDvRsTfRsSK/nYUEecDfwicm1J6J7ARGHvQal30fv85eLmkGmfAkATwOPDHdD9h+WHgeeCulNIXgLOA+1NKnwGKPM65DvhtSumViDgNmN3HOr8BJkfECRHxFuC/ZPktJA0bBgxJ0B0w/gS4qTI9A/hF5fXZwP+rvH6twL5+
BIyJiM10XxFZv39BRNwXEf8hpbSP7ns9HgV+CGztuYP96w3wd5E0DHiTpyRSSn8P/H2PWc8D10TE88DpwNLKOBTPFdjXq8DFh1k2v8frpcDS/taTVJsipVR2DZJqWEScBDwC7Mw9FkblWyjNwInAmSmlF3LuX9LQMWBIkqTsvAdDkiRlZ8CQJEnZGTAkSVJ2BgxJkpSdAUOSJGVnwJAkSdkZMCRJUnYGDEmSlN2/AQu71OXBjIa+AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 648x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "counts, bins = np.histogram(v_th_before)\n",
    "fig, ax = plt.subplots(figsize=(9,5))\n",
    "ax.hist(bins[:-1], bins, weights=counts, histtype='step', label='before')\n",
    "counts, bins = np.histogram(v_th_after)\n",
    "ax.hist(bins[:-1], bins, weights=counts, histtype='step', label='after')\n",
    "ax.set_xlabel('$v_{th}$ [a.u.]')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions\n",
    "\n",
    "This concludes the tutorial. As we have seen, incorporating *neuron parameters* into the optimisation in Norse mainly requires familiarity with how PyTorch handles parameters. At the moment, this optimisation is only supported for the surrogate gradient implementation, not for the discretised adjoint implementation. This is not a fundamental limitation, though, and support for neuron parameter gradients could be added to the adjoint implementation as well.\n",
    "\n",
    "Several recent publications have explored optimising neuron parameters, both to show that the resulting time constant distributions more closely resemble those measured in biological cells and to demonstrate advantages in ML applications. As always, the challenge in practically demonstrating an advantage of incorporating neuron parameters into the optimisation lies in careful, controlled experiments and in choosing the right tasks. While this notebook demonstrates the general technique, it takes some shortcuts:\n",
    "\n",
    "- Optimisation works best if the values to be optimised are of an appropriate scale. This is not the case here for the inverse membrane time constant, which has values > 100. A more sophisticated approach would take this into account, for example by deriving this parameter from an appropriately scaled underlying value.\n",
    "- Similarly, we have chosen a rather naive initialisation of both the threshold and the inverse membrane time constant. A more careful initialisation would, for example, sample from a distribution that only produces positive values, since negative values of either parameter are not meaningful for a LIF neuron.\n",
    "- Finally, we have not demonstrated any benefit of jointly optimising neuron parameters and synaptic weights.\n",
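    "\n",
    "As a rough sketch of the first two points (an illustration under stated assumptions, not part of Norse), the helper `ScaledPositiveParameter` below is hypothetical: it stores a trainable value `raw` of order one and exposes the inverse membrane time constant as `scale * exp(raw)`. Because Adam's per-step updates are roughly `lr` in size regardless of a parameter's magnitude, optimising `raw` on this log scale changes the time constant by a comparable *relative* amount, and the exponential keeps it positive. The initial values are drawn from a log-normal distribution, which only produces positive samples; the median of 100 1/s (a membrane time constant of 10 ms) and the log-space spread of 0.2 are arbitrary choices. Wiring such a module into the parametrised recurrent cell is not shown.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class ScaledPositiveParameter(torch.nn.Module):\n",
    "    # Hypothetical helper (not part of Norse): the optimiser updates `raw`,\n",
    "    # which is of order one, while forward() returns scale * exp(raw),\n",
    "    # which is always positive and of the order of `scale`.\n",
    "    def __init__(self, initial_value: torch.Tensor, scale: float):\n",
    "        super().__init__()\n",
    "        self.scale = scale\n",
    "        self.raw = torch.nn.Parameter(torch.log(initial_value / scale))\n",
    "\n",
    "    def forward(self) -> torch.Tensor:\n",
    "        return self.scale * torch.exp(self.raw)\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "hidden_size = 100\n",
    "\n",
    "# Log-normal initialisation: the exponential of a Gaussian sample is always\n",
    "# positive. Median 100 1/s (tau_mem = 10 ms) and spread 0.2 are assumptions.\n",
    "init_distribution = torch.distributions.LogNormal(\n",
    "    loc=torch.log(torch.tensor(100.0)), scale=0.2\n",
    ")\n",
    "tau_mem_inv_init = init_distribution.sample((hidden_size,))\n",
    "\n",
    "tau_mem_inv = ScaledPositiveParameter(tau_mem_inv_init, scale=100.0)\n",
    "\n",
    "# The trainable values now live on a scale of order one instead of ~100.\n",
    "print(tau_mem_inv.raw.abs().max())\n",
    "```\n",
    "\n",
    "An exponential map was chosen here because it handles both the rescaling and the positivity constraint in one step; a softplus map would be a common alternative that grows more gently for large values.\n",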
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
